import os
import claripy
from rex.exploit import Exploit
from povsim import CGCPovSimulator
import tempfile
import compilerex
from angr.state_plugins.trace_additions import ChallRespInfo

import logging
l = logging.getLogger("rex.exploit.cgc.cgc_exploit")


def _get_byte(var_name):
    """Return the stdin byte offset encoded in an ``aeg_input`` variable name.

    The offset is taken from the fourth ``_``-separated field of the name and
    parsed as hexadecimal; a name with fewer fields raises IndexError.

    HACK: this depends on angr's symbolic-variable naming scheme instead of a
    real API -- see https://github.com/angr/angr/issues/922
    """
    fields = var_name.split("_")
    offset_hex = fields[3]
    return int(offset_hex, 16)


class CGCFormula(object):
    """An SMT-LIB2 formula rewritten into C string-literal lines for Boolector.

    Stdin byte variables (``aeg_input*``) are renamed to ``byte_<hex idx>`` and
    re-declared first, so their node ids are predictable; every other declared
    variable keeps its name.  ``smt_stmt`` holds the resulting text as a
    sequence of quoted C string literals, one per SMT line.
    """

    def __init__(self, solver):
        # the statement, as concatenated C string literals
        self.smt_stmt = None
        # variable name (after renaming) -> boolector node id
        self.name_to_id = dict()
        # original variable name -> name used inside smt_stmt
        self.name_mapping = dict()
        # the name of the btor (assigned by the caller)
        self.name = None

        self._clean_formula(solver)

    def _clean_formula(self, solver):
        """Fill ``smt_stmt``, ``name_to_id`` and ``name_mapping`` from ``solver``."""
        smt = solver._get_solver().to_smt2()

        # rename every stdin byte variable to a stable "byte_<idx>" name
        byte_names = []
        for orig_name in solver.variables:
            if not orig_name.startswith("aeg_input"):
                continue
            new_name = "byte_%x" % _get_byte(orig_name)
            byte_names.append(new_name)
            smt = smt.replace(orig_name, new_name)
            self.name_mapping[orig_name] = new_name

        # skip the first two header lines and the trailing check-sat / empty line
        body_lines = smt.split("\n")[2:][:-2]
        statement_lines = [ln for ln in body_lines if "declare-fun" not in ln]
        declaration_lines = [ln for ln in body_lines if "declare-fun" in ln]

        # every non-declaration line becomes a quoted C literal ending in "\n"
        statements = "".join("\"%s\"\n" % (ln + "\\n") for ln in statement_lines)

        # byte variables are declared first, in discovery order; ids start at 2
        decl_template = "\"(declare-fun %s () (_ BitVec %d))\\n\""
        decls = []
        for name in byte_names:
            decls.append(decl_template % (name, 8))
            self.name_to_id[name] = len(self.name_to_id) + 2

        # then every other declared variable, keeping its original declaration
        for ln in declaration_lines:
            var = ln.split()[1]
            if var.startswith("byte"):
                continue
            decls.append("\"" + ln + "\\n\"")
            self.name_mapping[var] = var
            self.name_to_id[var] = len(self.name_to_id) + 2

        self.smt_stmt = "\n".join(decls) + "\n" + statements


class IntInfo(object):
    """Location of a textual integer inside the stdin or stdout stream.

    All attributes start as None and are filled in by the caller:
    ``start`` (first byte index), ``size`` (byte count), ``base`` (numeric
    base of the text) and ``var_name`` (name of the symbolic variable).
    """

    def __init__(self):
        # assigned left to right: start, size, base, var_name
        self.start = self.size = self.base = self.var_name = None


class CGCExploit(Exploit):
    '''
    A CGC exploit object, offers more flexibility than an Exploit object for
    the sake of the game.
    https://github.com/CyberGrandChallenge/cgc-release-documentation/blob/master/walk-throughs/understanding-cfe-povs.md

    Subclasses are responsible for setting ``self._mem`` (the payload
    expression) and implementing :meth:`dump_python` / :meth:`dump_c`.  The
    GENERATION helpers turn the crash state's constraints into C code
    (``self._solver_code``).  That code re-solves the constraints with
    Boolector while the PoV runs, sending stdin and receiving stdout in
    between as needed.
    '''

    def __init__(self, crash, cgc_type, bypasses_nx, bypasses_aslr):
        super(CGCExploit, self).__init__(crash, bypasses_nx, bypasses_aslr)

        self.cgc_type = cgc_type
        self.method_name = 'unclassified'

        # set by all exploits, represents the payload
        self._mem = None

        # arguments CGC infra will send us (ASTs whose .variables are looked up
        # by _create_constrain_vals); None until a subclass sets them
        self._arg_vars = None
        # the payload stuff
        self._payload_len = None
        self._raw_payload = None
        # the length of the receiving buffer (concrete stdout size of the crash state)
        self._recv_buf_len = self.crash.state.solver.eval(self.crash.state.posix.stdout.size)
        # C code which handles challenge response
        self._solver_code = ""
        # the cleaned formulas, one per generated btor_<i> solver
        self._formulas = []
        # the integer starts, keyed by StrToInt / IntToStr variable name
        self._stdin_int_infos = {}
        self._stdout_int_infos = {}
        # the same IntInfo objects ordered by their start index
        self._sorted_stdin_int_infos = []
        self._sorted_stdout_int_infos = []

    # DUMPING

    def dump(self, filename=None):
        '''
        dumps a CGC exploit.

        XXX: just python script for now.

        :param filename: forwarded to :meth:`dump_python`.
        :return: whatever :meth:`dump_python` returns.
        '''

        return self.dump_python(filename)

    def dump_python(self, filename=None):
        '''
        Dump the exploit as a Python PoV; must be implemented by subclasses.

        :raises NotImplementedError: always, in this base class.
        '''

        raise NotImplementedError(
            "It is the responsibility of subclasses to implement this method"
            )

    def dump_c(self, filename=None):
        '''
        Dump the exploit as C PoV source; must be implemented by subclasses.

        :meth:`dump_binary` calls this with no arguments and compiles the
        returned string.

        :raises NotImplementedError: always, in this base class.
        '''

        raise NotImplementedError(
            "It is the responsibility of subclasses to implement this method"
            )

    def dump_binary(self, filename=None):
        '''
        Compile the C PoV returned by :meth:`dump_c` with compilerex.

        :param filename: if given, the compiled binary is written there, made
                         executable (0755) and None is returned.
        :return: the compiled binary when ``filename`` is None.
        '''
        c_code = self.dump_c()
        compiled_result = compilerex.compile_from_string(c_code)
        if filename is not None:
            with open(filename, 'wb') as f:
                f.write(compiled_result)
            os.chmod(filename, 0o755)
            return None

        return compiled_result

    # GENERATION

    def _generate_formula(self, extra_vars_to_solve=None):
        """
        This function is used to generate the equations which are inserted inside C exploits

        Works on a copy of the crash state.  It records the concrete payload
        (``self._raw_payload``, ``self._payload_len``) and then emits the
        solver code through :meth:`_create_solvers`.

        :param extra_vars_to_solve: optional list of ASTs that must also be
            solved; they are passed to ``_merged_solver_for(lst=...)`` along
            with ``self._mem``.
        """

        if extra_vars_to_solve is None:
            extra_vars_to_solve = []

        st = self.crash.state.copy()

        self._prepare_chall_resp(st)

        ft = st.solver._solver._merged_solver_for(lst=[self._mem]+extra_vars_to_solve)
        # filter out constants: only keep independent solvers with more than one variable
        the_vars = set()
        split = ft.split()
        for solver in split:
            if len(solver.variables) > 1:
                the_vars.update(solver.variables)
        ft = st.solver._solver._merged_solver_for(names=the_vars)

        # try to use original crash_input instead of random ones if some input is still symbolic at this stage
        # TODO: use the reuse_input_constraints API in Exploit
        length = st.solver.eval(st.posix.stdin.size)
        constraints = []
        if hasattr(self.crash, "use_crash_input") and self.crash.use_crash_input:
            constraints = self.reuse_input_constraints(self.crash.sim_input, 0, self.crash.crash_input, state=st)

        self._payload_len = length
        self._raw_payload = st.posix.dumps(0, extra_constraints=constraints)

        self._create_solvers(ft, extra_vars_to_solve)

    def _create_solvers(self, ft, extra_vars_to_solve=None):
        """
        Generate the Boolector C code for every independent piece of ``ft``.

        ``ft`` is split into independent solvers.  Each solver is emitted once
        per stdin start index reported by :meth:`_get_stdin_start_indices`, in
        increasing stdin order, as a ``btor_<i>`` block.  Each block sends and
        receives as needed, parses the cleaned formula, adds the known
        constraints and writes the solved bytes back into ``payload``.  The
        result is stored in ``self._solver_code``, with the formulas in
        ``self._formulas``; both are reset on every call.

        :param ft: claripy solver holding the constraints to reproduce.
        :param extra_vars_to_solve: optional ASTs (or variable names).  The
            solvers containing them are solved once more at the end, using the
            crash state's stdin read position and stdout write position.
        """
        split_solvers = ft.split()

        if extra_vars_to_solve is None:
            extra_vars_to_solve = []

        # make sure there is a chall_resp_info plugin
        if not self.crash.state.has_plugin("chall_resp_info"):
            # register a blank one
            self.crash.state.register_plugin("chall_resp_info", ChallRespInfo())

        # figure out start indices for all solvers
        stdin_solver = []
        for solver in split_solvers:
            for idx, min_stdout_needed in self._get_stdin_start_indices(solver):
                stdin_solver.append((idx, min_stdout_needed, solver))

        # get an extra solver for the extra_vars_to_solve.
        # solver.variables holds variable *names*, while _generate_formula passes
        # ASTs here.  Compare against the names of those ASTs; plain names
        # are accepted as-is.
        extra_var_names = set()
        for var in extra_vars_to_solve:
            if hasattr(var, "variables"):
                extra_var_names.update(var.variables)
            else:
                extra_var_names.add(var)

        merged_extra_solver = None
        if extra_var_names:
            important_solvers = [x for x in split_solvers if len(x.variables & extra_var_names) > 0]
            if len(important_solvers) > 0:
                merged_extra_solver = important_solvers[0]
                for s in important_solvers[1:]:
                    merged_extra_solver = merged_extra_solver.combine(s)

        # sort them
        stdin_solver = sorted(stdin_solver, key=lambda x: x[0])

        # get int nums
        self._stdin_int_infos = {}
        self._stdout_int_infos = {}
        for solver in split_solvers:
            for info in self._get_stdin_int_infos(solver):
                self._stdin_int_infos[info.var_name] = info
            for info in self._get_stdout_int_infos(solver):
                self._stdout_int_infos[info.var_name] = info

        # sort them
        self._sorted_stdin_int_infos = sorted(self._stdin_int_infos.values(), key=lambda x: x.start)
        self._sorted_stdout_int_infos = sorted(self._stdout_int_infos.values(), key=lambda x: x.start)

        # FIXME FLAG THIS WILL NEED TO BE CHANGED
        if len(extra_vars_to_solve) > 0:
            if merged_extra_solver is not None:
                stdin_solver.append((self.crash.state.solver.eval(self.crash.state.posix.fd[0].read_pos),
                                     self.crash.state.solver.eval(self.crash.state.posix.fd[1].write_pos),
                                     merged_extra_solver))
            else:
                # a None solver cannot be turned into a CGCFormula; skip it
                l.warning("None of the extra variables to solve appear in any solver, not solving them separately")

        l.debug("There are %d solvers after splitting", len(stdin_solver))

        # start from scratch, like _solver_code, so formulas match the emitted code
        self._formulas = []
        self._solver_code = ""
        for i, (min_stdin, min_stdout_needed, solver) in enumerate(stdin_solver):

            formula = CGCFormula(solver)
            self._formulas.append(formula)

            btor_name = "btor_%d" % i
            formula.name = btor_name
            solver_code = ""

            # possibly send more
            solver_code += self._create_send_stdin(min_stdin, min_stdout_needed)
            # we need to read until we get the bytes
            solver_code += self._create_read_bytes(min_stdout_needed)

            # now we have all the bytes we needed
            # parse the formula
            solver_code += self._create_boolector_parse(btor_name, formula)
            # constrain any "input" variables (regval, addr)
            solver_code += self._create_constrain_vals(solver, btor_name, formula)
            # add constraints to any stdin we've already sent
            solver_code += self._create_constrain_stdin(solver, btor_name, formula)
            # add constraints to stdout for the bytes we got
            solver_code += self._create_constrain_stdout(solver, btor_name, formula)
            # add constraints to any integers we have already used
            solver_code += self._create_constrain_integers(solver, btor_name, formula)
            # now create the byte setters
            solver_code += self._create_byte_setters(solver, btor_name, formula)
            self._solver_code += solver_code + "\n"

        # we are done
        l.debug("done creating solvers")

    @staticmethod
    def _make_c_int_arr(list_of_ints):
        """Render ``list_of_ints`` as a C array initializer, e.g. ``{1, 2, 3}``."""
        return "{" + ", ".join(str(x) for x in list_of_ints) + "}"

    @staticmethod
    def _build_int_infos(solver, prefix, get_indices):
        """
        Build an IntInfo for each variable of ``solver`` whose name starts with
        ``prefix`` and for which ``get_indices`` returns at least one index.
        """
        int_infos = []
        for v in solver.variables:
            if not v.startswith(prefix):
                continue
            indices = get_indices(v)
            if len(indices) > 0:
                info = IntInfo()
                info.var_name = v
                info.start = min(indices)
                info.size = max(indices)+1-min(indices)
                # the numeric base is the second "_"-separated field of the name
                info.base = int(v.split("_")[1], 10)
                int_infos.append(info)
        return int_infos

    def _get_stdin_int_infos(self, solver):
        """Return IntInfos for the StrToInt variables of ``solver`` (stdin positions)."""
        chall_resp_info = self.crash.state.get_plugin("chall_resp_info")
        # get the indices from the ATOI nodes
        return self._build_int_infos(solver, "StrToInt", chall_resp_info.get_stdin_indices)

    def _get_stdout_int_infos(self, solver):
        """Return IntInfos for the IntToStr variables of ``solver`` (stdout positions)."""
        chall_resp_info = self.crash.state.get_plugin("chall_resp_info")
        # get the indices from the ITOA nodes
        return self._build_int_infos(solver, "IntToStr", chall_resp_info.get_stdout_indices)

    def _get_stdin_start_indices(self, solver):
        """
        Split ``solver``'s stdin bytes into groups that can be solved together.

        need for each stdin byte the min amount of stdout we saw before it
        we order the solvers by min stdin first
        A1 A2 where A2 needs more data, A2 will also have a different number of output bytes
        if a solver has input bytes with different numbers of stdout, we solve, constrain, send, receive repeat

        :return: a list of ``(stdin_start_idx, min_stdout_seen)`` tuples.  A
                 new group starts whenever the stdout position at the first
                 constraint, or at the read, changes between consecutive stdin
                 bytes.
        """

        chall_resp_info = self.crash.state.get_plugin("chall_resp_info")

        # get the indices used by the solver
        stdin_indices = sorted(self._get_stdin_bytes(solver))
        stdout_indices = sorted(self._get_stdout_bytes(solver))

        # get the indices from the ATOI nodes
        for v in solver.variables:
            if v.startswith("StrToInt"):
                stdin_indices = sorted(stdin_indices + list(chall_resp_info.get_stdin_indices(v)))
            if v.startswith("IntToStr"):
                stdout_indices = sorted(stdout_indices + list(chall_resp_info.get_stdout_indices(v)))

        # return early if we don't care about stdout
        if len(stdout_indices) == 0:
            if len(stdin_indices) == 0:
                return [(self.crash.state.solver.eval(self.crash.state.posix.fd[0].read_pos), 0)]
            return [(min(stdin_indices), 0)]

        # now we want to group them by the stdout index
        # this is a mapping from stdin position to the stdout postition the first time a constraint was added on that byte
        stdin_min_stdout_constraints = dict(chall_resp_info.stdin_min_stdout_constraints)
        # this is a mapping from stdin position to the stdout position when the byte was read in
        stdin_min_stdout_reads = dict(chall_resp_info.stdin_min_stdout_reads)

        # if any stdin index isn't in the dictionary then we saw no constraints on it while stepping...
        # in this case we do not need to make a new solver when it changes
        # so lets propagate the min stdout value
        curr_stdout = -1
        for idx in stdin_indices:
            if idx in stdin_min_stdout_constraints:
                curr_stdout = stdin_min_stdout_constraints[idx]
            else:
                stdin_min_stdout_constraints[idx] = curr_stdout

        # if there are any indices that are not in the stdin_min_stdout_reads it's probably bad
        # but let's fill it in... just in case
        curr_stdout = -1
        for idx in stdin_indices:
            if idx in stdin_min_stdout_reads:
                curr_stdout = stdin_min_stdout_reads[idx]
            else:
                stdin_min_stdout_reads[idx] = curr_stdout

        # now we know every stdin_idx is in stdin_min_stdout, we can look for changes and say those are start indices
        # -2 is less than all of stdin_min_stdout so we will add the first idx to start_indices
        curr_stdout_constraint = -2
        curr_stdout_read = -2
        start_indices = []
        for idx in stdin_indices:
            if stdin_min_stdout_constraints[idx] != curr_stdout_constraint or \
                            stdin_min_stdout_reads[idx] != curr_stdout_read:
                curr_stdout_constraint = stdin_min_stdout_constraints[idx]
                curr_stdout_read = stdin_min_stdout_reads[idx]
                start_indices.append((idx, curr_stdout_constraint))

        return start_indices

    @staticmethod
    def _create_boolector_parse(btor_name, formula):
        """Emit C that creates ``btor_name`` and parses ``formula.smt_stmt`` into it."""
        c_code = ""
        smt_name = "smt_stmt_%s" % btor_name
        c_code += "  Btor *%s = boolector_new();\n" % btor_name
        c_code += '  boolector_set_opt(%s, "model_gen", 1);\n' % btor_name
        c_code += "  const char *%s = %s;\n" % (smt_name, formula.smt_stmt)
        c_code += "  boolector_parse(%s, %s, &error, &status);\n" % (btor_name, smt_name)
        c_code += "  if (error)\n"
        c_code += "    die(error);\n"

        return c_code

    def _create_constrain_vals(self, solver, btor_name, formula):
        """
        Emit ``constrain_<var>(btor, id)`` calls for every variable of
        ``solver`` that belongs to one of ``self._arg_vars``.

        Stdout, stdin, flag and random variables are skipped.  When
        ``self._arg_vars`` has not been set (None), no calls are emitted.
        """
        # these are the values we need to get from c, could be t1vals, t2vals, etc
        arg_vars = self._arg_vars or []
        c_code = ""
        for v in solver.variables:
            # skip stdout, stdin, flag
            if v.startswith(("file_stdout", "aeg_input", "cgc-flag", "random")):
                continue
            # call constrain if it's an arg var
            if any(v in arg_var.variables for arg_var in arg_vars):
                func_name = "constrain_" + v
                c_code += "  %s(%s, %d);\n" % (func_name, btor_name, formula.name_to_id[v])
        return c_code

    def _create_byte_setters(self, solver, btor_name, formula):
        """
        Emit C that copies the solved values back into ``payload``.

        Raw stdin bytes are written directly.  StrToInt values are written via
        ``set_payload_int_solve_result`` with the variable's base and its
        position in ``self._sorted_stdin_int_infos``.
        """
        # generate byte setters for stdin
        set_bytes = ""
        for b in solver.variables:
            if b.startswith("aeg_input"):
                bid = formula.name_to_id[formula.name_mapping[b]]
                set_bytes += "  cur_byte = boolector_match_node_by_id(%s, %d);\n" % (btor_name, bid)
                set_bytes += "   payload[real_payload_off(%d)] " % _get_byte(b)
                set_bytes += "= to_char(boolector_bv_assignment(%s, cur_byte));\n" % btor_name

        # generate byte setters for strtoint inputs
        chall_resp_info = self.crash.state.get_plugin("chall_resp_info")
        for v in solver.variables:
            if v.startswith("StrToInt"):
                stdin_indices = chall_resp_info.get_stdin_indices(v)
                if len(stdin_indices) == 0:
                    continue
                bid = formula.name_to_id[formula.name_mapping[v]]
                base = int(v.split("_")[1], 10)
                intnum = self._sorted_stdin_int_infos.index(self._stdin_int_infos[v])
                set_bytes += "  set_payload_int_solve_result(%s, %d, %d, %d);\n" % (btor_name, bid, base, intnum)
        return set_bytes

    @staticmethod
    def _create_constrain_stdin(solver, btor_name, formula):
        """Emit C asserting that every stdin byte already sent keeps its sent value."""
        # constrain all bytes of stdin we've already sent
        code = ""
        for v in solver.variables:
            if v.startswith("aeg_input"):
                byte_index = _get_byte(v)
                bid = formula.name_to_id[formula.name_mapping[v]]
                code += "  if (payload_off > %#x) {\n" % byte_index
                code += "    payload_val = boolector_unsigned_int(%s, payload[real_payload_off(%#x)], 8);\n" % (btor_name, byte_index)
                code += "    payload_val_var = boolector_match_node_by_id(%s, %d);\n" % (btor_name, bid)
                code += "    payload_con = boolector_eq(%s, payload_val_var, payload_val);\n" % btor_name
                code += "    boolector_assert(%s, payload_con);\n" % btor_name
                code += "  }\n"
        return code

    @staticmethod
    def _create_constrain_stdout(solver, btor_name, formula):
        """Emit C asserting that every received stdout byte equals its variable."""
        code = ""
        for v in solver.variables:
            if v.startswith("file_stdout"):
                byte_index = int(v.split('_')[2], 16)
                bid = formula.name_to_id[formula.name_mapping[v]]
                code += "  if (recv_off > %#x) {\n" % byte_index
                code += "    stdout_val = boolector_unsigned_int(%s, received_data[real_recv_off(%#x)], 8);\n" % (btor_name, byte_index)
                code += "    stdout_val_var = boolector_match_node_by_id(%s, %d);\n" % (btor_name, bid)
                code += "    stdout_con = boolector_eq(%s, stdout_val_var, stdout_val);\n" % btor_name
                code += "    boolector_assert(%s, stdout_con);\n" % btor_name
                code += "  }\n"
        return code

    def _create_constrain_integers(self, solver, btor_name, formula):
        """
        Emit C that pins already-transferred integers to their parsed values
        and then checks satisfiability.

        IntToStr values are parsed from ``received_data`` and StrToInt values
        from ``payload``, both with ``strtoul`` as 32-bit values.  The
        generated code calls ``die("unsat\\n")`` when ``boolector_sat`` does
        not return 10.
        """
        chall_resp_info = self.crash.state.get_plugin("chall_resp_info")
        code = ""
        for v in solver.variables:
            if v.startswith("IntToStr"):
                stdout_indices = chall_resp_info.get_stdout_indices(v)
                if len(stdout_indices) == 0:
                    continue
                bid = formula.name_to_id[formula.name_mapping[v]]
                start = min(stdout_indices)
                size = max(stdout_indices)-min(stdout_indices)+1
                base = int(v.split("_")[1], 10)
                code += "  if (recv_off > %#x) {\n" % start
                code += "    memset(temp_int_buf, 0, sizeof(temp_int_buf));\n"
                code += "    memcpy(temp_int_buf, received_data + real_recv_off(%#x), %#x+10);\n" % (start, size)
                code += "    temp_int = strtoul(temp_int_buf, NULL, %d);\n" % base
                code += "    // recv_int_corrections is already set when receiving\n"
                code += "    int_val = boolector_unsigned_int(%s, temp_int, 32);\n" % btor_name
                code += "    int_val_var = boolector_match_node_by_id(%s, %d);\n" % (btor_name, bid)
                code += "    int_con = boolector_eq(%s, int_val_var, int_val);\n" % btor_name
                code += "    boolector_assert(%s, int_con);\n" % btor_name
                code += "  }\n"

        for v in solver.variables:
            if v.startswith("StrToInt"):
                stdin_indices = chall_resp_info.get_stdin_indices(v)
                if len(stdin_indices) == 0:
                    continue
                bid = formula.name_to_id[formula.name_mapping[v]]
                start = min(stdin_indices)
                size = max(stdin_indices)-min(stdin_indices)+1
                base = int(v.split("_")[1], 10)
                code += "  if (payload_off > %#x) {\n" % start
                code += "    memset(temp_int_buf, 0, sizeof(temp_int_buf));\n"
                # NOTE(review): unlike the byte constraints this indexes payload
                # without real_payload_off(); per the emitted comment below that
                # appears intentional -- confirm against the C template
                code += "    memcpy(temp_int_buf, payload + %#x, %#x+10);\n" % (start, size)
                code += "    temp_int = strtoul(temp_int_buf, NULL, %d);\n" % base
                code += "    // payload int corrections don't need to happen here either they happen when solving\n"
                code += "    int_val = boolector_unsigned_int(%s, temp_int, 32);\n" % btor_name
                code += "    int_val_var = boolector_match_node_by_id(%s, %d);\n" % (btor_name, bid)
                code += "    int_con = boolector_eq(%s, int_val_var, int_val);\n" % btor_name
                code += "    boolector_assert(%s, int_con);\n" % btor_name
                code += "  }\n"

        code += "  if (boolector_sat(%s) != 10){\n" % btor_name
        code += "    die(\"unsat\\n\");\n"
        code += "  }\n\n"
        return code

    def _create_read_bytes(self, end_idx):
        """
        Emit C that receives stdout until ``recv_off`` reaches ``end_idx``.

        It grows ``received_data`` when ``fixup_recv_amount`` reports more
        data than was actually received.  Returns "" when ``end_idx`` is 0.
        """
        if end_idx == 0:
            return ""
        code = "  recv_extra_for_int = check_for_recv_extra(recv_off, %#x - recv_off);\n" % end_idx
        code += "  if (recv_off < %#x) {\n" % end_idx
        code += "    recv_amount = receive_n_timeout(0, received_data+real_recv_off(recv_off), %#x - recv_off+recv_extra_for_int, 300000);\n" % end_idx
        code += "    fake_recv_amount = fixup_recv_amount(recv_off, recv_amount);\n"
        code += "    // realloc the buffer if necessary\n"
        code += "    if (fake_recv_amount > recv_amount) {\n"
        code += "      old_recv_buf_len = recv_buf_len;\n"
        code += "      recv_buf_len += fake_recv_amount - recv_amount;\n"
        code += "      received_data = realloc_zero(received_data, old_recv_buf_len, recv_buf_len);\n"
        code += "    }\n"
        code += "    recv_off += fake_recv_amount;\n"
        code += "  }\n"
        return code

    @staticmethod
    def _create_send_stdin(num_bytes, min_stdout_needed):
        """
        Emit C that sends ``num_bytes`` more payload bytes when fewer than
        ``min_stdout_needed`` stdout bytes have been received.

        Returns "" when no stdout is needed.
        """
        # we send if we need to recv more
        if min_stdout_needed == 0:
            return ""
        code = "  if (recv_off < %#x) {\n" % min_stdout_needed
        code += "    send_amount = real_payload_off(payload_off+%#x) - real_payload_off(payload_off);\n" % num_bytes
        code += "    send_all(1, payload+real_payload_off(payload_off), send_amount);\n"
        code += "    payload_off += %#x;\n" % num_bytes
        code += "  }\n"
        return code

    @staticmethod
    def _get_stdout_bytes(solver):
        """Return the set of stdout byte indices named by ``file_stdout_<hex>`` variables."""
        byte_indices = set()
        for v in solver.variables:
            if v.startswith("file_stdout_"):
                byte_index = int(v.split("_")[2], 16)
                byte_indices.add(byte_index)
        return byte_indices

    @staticmethod
    def _get_stdin_bytes(solver):
        """Return the set of stdin byte indices used by ``aeg_input`` variables."""
        byte_indices = set()
        for v in solver.variables:
            if v.startswith("aeg_input"):
                byte_index = _get_byte(v)
                byte_indices.add(byte_index)
        return byte_indices

    def _prepare_chall_resp(self, state):
        """
        Detect challenge-response constraints in ``state`` and extend ``self._mem``.

        Flag/random constraints that also involve stdin or variables added by
        chall_resp_info count as controllable.  If any exist, stdout is tied to
        fresh ``file_stdout_<i>`` byte variables.  The stdout bytes related to
        the payload, plus the StrToInt/IntToStr variables, are then
        concatenated onto ``self._mem`` so they end up in the generated
        formula.  ``state`` is modified in place.
        """
        # now we need to find the challenge response stuff
        # first break constraints at And's (one level only)
        constraints = []
        for c in state.solver.constraints:
            if c.op == "And":
                constraints.extend(c.args)
            else:
                constraints.append(c)

        # filter for possible flag constraints
        filtered_constraints = []
        for c in constraints:
            if any(v.startswith("cgc-flag") or v.startswith("random") for v in c.variables):
                filtered_constraints.append(c)

        self.filter_uncontrolled_constraints(state)
        # now separate into constraints we can probably control, and those we can't

        controllable_constraints = []
        uncontrollable_constraints = []
        if not state.has_plugin("chall_resp_info"):
            # register a blank one
            state.register_plugin("chall_resp_info", ChallRespInfo())

        chall_resp_info = state.get_plugin("chall_resp_info")

        for c in filtered_constraints:
            if any(v.startswith("aeg_input") for v in c.variables) or \
                    any(v in chall_resp_info.vars_we_added for v in c.variables):
                controllable_constraints.append(c)
            elif any(v.startswith("output_var") for v in c.variables):
                # an output like a leak
                pass
            else:
                # uncontrollable constraints will show up as zen constraints etc
                uncontrollable_constraints.append(c)

        if len(controllable_constraints) > 0:
            l.debug("Challenge response detected!")
            file_1 = state.posix.stdout
            stdout = file_1.load(0, file_1.size)

            stdout_len = state.solver.eval(file_1.size)
            stdout_bvs = [claripy.BVS("file_stdout_%#x" % i, 8, explicit_name=True) for i in range(stdout_len)]
            stdout_bv = claripy.Concat(*stdout_bvs)

            state.add_constraints(stdout == stdout_bv)
            # we call simplify to separate the contraints/dependencies
            state.solver.simplify()

            merged_solver = state.solver._solver._merged_solver_for(lst=[self._mem] + controllable_constraints)
            # todo here we can verify that there are actually stdout bytes here, otherwise we have little hope

            # add the important stdout vars to mem
            needed_vars = []
            for bv in stdout_bvs:
                if len(bv.variables & merged_solver.variables) != 0:
                    needed_vars.append(bv)

            # add the str_to_int vars and int_to_str vars
            for _, v in chall_resp_info.str_to_int_pairs:
                needed_vars.append(v)
            for v, _ in chall_resp_info.int_to_str_pairs:
                needed_vars.append(v)
            self._mem = claripy.Concat(self._mem, *needed_vars)

    @staticmethod
    def filter_uncontrolled_constraints(state):
        """
        Drop constraints that involve only flag-page/random variables.

        Constraints whose cache key matches a zen constraint are always kept.
        The solver plugin is rebuilt from the remaining constraints.  Requires
        ``state`` to have a ``zen_plugin`` registered.
        """
        # remove constraints from the state which involve only the flagpage or random
        # this solves a problem with CROMU_00070, where the floating point
        # operations have to be done concretely and constrain the flagpage
        # to being a single value
        # we do not remove zen constraints
        zen_cache_keys = set(x.cache_key for x in state.get_plugin("zen_plugin").zen_constraints)
        new_cons = [ ]
        for con in state.solver.constraints:
            if con.cache_key in zen_cache_keys or \
                    not all(v.startswith("cgc-flag") or v.startswith("random") for v in con.variables):
                new_cons.append(con)

        state.release_plugin('solver')
        state.add_constraints(*new_cons)
        state.downsize()

        state.solver.simplify()
        state.solver._solver.result = None

    # TESTING

    def test_binary(self, enable_randomness=True, times=1, timeout=15):
        """
        Test the binary generated

        Compiles the PoV into a temporary file under /tmp and runs it against
        ``self.crash.binary`` with CGCPovSimulator.  The temporary file is
        removed even if compiling or simulating raises.

        :return: the result of ``CGCPovSimulator.test_binary_pov``.
        """

        # mkstemp creates the file atomically (mktemp only picks a name and is
        # open to races); dump_binary then overwrites it
        fd, pov_binary_filename = tempfile.mkstemp(dir='/tmp', prefix='rex-pov-')
        os.close(fd)
        try:
            # dump the binary code
            self.dump_binary(filename=pov_binary_filename)
            os.chmod(pov_binary_filename, 0o755)

            pov_tester = CGCPovSimulator()
            result = pov_tester.test_binary_pov(
                    pov_binary_filename,
                    self.crash.binary,
                    enable_randomness=enable_randomness,
                    timeout=timeout,
                    times=times)
        finally:
            # remove the generated pov
            os.remove(pov_binary_filename)

        return result
